package com.wisdom.system.common.mybatis;

import com.alibaba.druid.pool.DruidDataSource;
import com.wisdom.system.common.util.StringUtil;


import javax.sql.DataSource;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Static helpers for building SQL text: escaping LIKE search terms,
 * expanding a simple SQL template, and reading the schema name from a JDBC URL.
 */
public class SqlUtil {
    /** Marker for values referenced as {@code #{}} in mapper.xml. */
    private static final String H = "#";
    /** Marker for values referenced as {@code ${}} in mapper.xml. */
    private static final String S = "$";

    /**
     * Escapes a LIKE search term that is referenced as {@code #{}} in mapper.xml.
     * The term is wrapped in {@code %} on both sides.
     *
     * @param str the LIKE search term
     * @return the escaped term, or {@code null} when {@code str} is null or empty
     */
    public static String likeEscapeH(String str) {
        return likeEscapeZ(str, H, true, true);
    }

    /**
     * Escapes a LIKE search term that is referenced as {@code ${}} in mapper.xml.
     * The term is wrapped in {@code %} on both sides and enclosed in single quotes.
     *
     * @param str the LIKE search term
     * @return the escaped, quoted term, or {@code null} when {@code str} is null or empty
     */
    public static String likeEscapeS(String str) {
        return likeEscapeZ(str, S, true, true);
    }

    /**
     * Escapes a LIKE search term by wrapping {@code [}, {@code %}, {@code ^} and
     * {@code !} in square brackets. For the {@code "$"} type the result is also
     * enclosed in single quotes (with a leading and trailing space) and embedded
     * single quotes are doubled.
     *
     * <p>{@code ]} is not handled, and {@code _} is intentionally left unescaped.
     *
     * @param str   the LIKE search term
     * @param type  how the value is referenced in mapper.xml, {@code "#"} or {@code "$"};
     *              any value other than {@code "$"} is treated like {@code "#"}
     * @param start whether to prepend {@code %}
     * @param end   whether to append {@code %}
     * @return the escaped term, or {@code null} when {@code str} is null or empty
     */
    public static String likeEscapeZ(String str, String type, boolean start, boolean end) {
        // isEmpty() instead of == "": reference comparison misses non-interned empty strings.
        if (str == null || str.isEmpty()) {
            return null;
        }
        boolean inline = S.equals(type);
        StringBuilder buffer = new StringBuilder(str.length() + 8);
        // The order of the appends below must not change.
        if (inline) {
            buffer.append(" '");
        }
        if (start) {
            buffer.append("%");
        }
        int len = str.length();
        for (int i = 0; i < len; i++) {
            char c = str.charAt(i);
            switch (c) {
                case '\'':
                    if (inline) {
                        // Inside a quoted literal a single quote is written as two.
                        buffer.append("''");
                    } else {
                        buffer.append(c);
                    }
                    break;
                case '[':
                    buffer.append("[[]");
                    break;
                case '%':
                    buffer.append("[%]");
                    break;
                case '^':
                    buffer.append("[^]");
                    break;
                case '!':
                    buffer.append("[!]");
                    break;
                default:
                    // Includes '_' and ']', which are passed through unchanged.
                    buffer.append(c);
            }
        }
        if (end) {
            buffer.append("%");
        }
        if (inline) {
            buffer.append("' ");
        }
        return buffer.toString();
    }

    /**
     * Expands a SQL template using the given parameters.
     *
     * <p>Two constructs are handled for each entry of {@code param}:
     * <ul>
     *   <li>Conditional block {@code <%=key ... %>}, for example
     *       {@code <%=deptCode and a.FK_DEPT_CODE = #{deptCode} %>}: removed entirely
     *       when the value is null or empty, otherwise replaced by its inner text.</li>
     *   <li>Placeholder {@code #{key}}: replaced by the value as a single-quoted
     *       literal, with embedded single quotes doubled.</li>
     * </ul>
     * Any {@code <%= ... %>} block still present afterwards is removed.
     *
     * <p>Blocks are matched on a single line only, because the underlying pattern
     * does not match across line terminators.
     *
     * <p>NOTE(review): values are inlined into the SQL text. Only single quotes are
     * escaped here; prefer bound parameters for untrusted input.
     *
     * @param sql   the SQL template; may be {@code null}
     * @param param parameter names mapped to values; may be {@code null}
     * @return the expanded SQL, or {@code null} when {@code sql} is {@code null}
     */
    public static String sqlConvert(String sql, Map<String, Object> param) {
        if (sql == null) {
            // Previously this fell through to the final cleanup and threw NullPointerException.
            return null;
        }
        if (param != null) {
            for (Map.Entry<String, Object> entry : param.entrySet()) {
                if (StringUtil.isEmpty(sql)) {
                    break; // nothing left to substitute
                }
                String key = entry.getKey();
                Object value = entry.getValue();
                String blockOpen = "<%=" + key;
                // Quote the key so regex metacharacters in it are literal, and require that
                // the key ends here so "dept" does not match a "<%=deptCode ...%>" block.
                List<String> bodies = subString(sql, Pattern.quote(blockOpen) + "(?!\\w)", "%>");

                boolean blank = value == null || StringUtil.isEmpty(String.valueOf(value));
                for (String body : bodies) {
                    sql = sql.replace(blockOpen + body + "%>", blank ? "" : body);
                }
                if (blank) {
                    continue;
                }
                sql = sql.replace("#{" + key + "}", quoteLiteral(value));
            }
        }
        // Drop conditional blocks whose key was not supplied at all.
        for (String body : subString(sql, "<%=", "%>")) {
            sql = sql.replace("<%=" + body + "%>", "");
        }
        return sql;
    }

    /** Renders a value as a single-quoted SQL literal, doubling embedded single quotes. */
    private static String quoteLiteral(Object value) {
        return "'" + String.valueOf(value).replace("'", "''") + "'";
    }

    /**
     * Collects the text found between each occurrence of a start and an end delimiter,
     * using the shortest possible match for each occurrence.
     *
     * <p>Both delimiters are used as regular expressions, so callers that need a
     * literal match must quote them. The text in between is matched with {@code .},
     * which does not cross line terminators.
     *
     * @param str      the text to search; {@code null} yields an empty list
     * @param strStart regular expression for the start delimiter
     * @param strEnd   regular expression for the end delimiter
     * @return the matched inner texts in order of appearance; never {@code null}
     */
    public static List<String> subString(String str, String strStart, String strEnd) {
        List<String> list = new ArrayList<>();
        if (str == null) {
            return list;
        }
        Matcher matcher = Pattern.compile(strStart + "(.*?)" + strEnd).matcher(str);
        while (matcher.find()) {
            list.add(matcher.group(1));
        }
        return list;
    }

    /**
     * Extracts the last path segment of the data source's JDBC URL, ignoring anything
     * after the first {@code ?}. For a URL such as
     * {@code jdbc:mysql://host:3306/mydb?useSSL=false} this yields {@code mydb}.
     *
     * @param dataSource the data source; it is cast to {@link DruidDataSource}, so any
     *                   other implementation causes a {@link ClassCastException}
     * @return the text after the last {@code /} of the URL, before any {@code ?}
     */
    public static String getTableSchema(DataSource dataSource) {
        String jdbcUrl = ((DruidDataSource) dataSource).getUrl();
        // Strip the query string first, then take the final path segment.
        String[] jdbcUrlArray = jdbcUrl.split("\\?");
        String[] urlArray = jdbcUrlArray[0].split("/");
        return urlArray[urlArray.length - 1];
    }
}
